{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "dc29a353-e0ea-42ed-bd17-d9efb998b8be",
   "metadata": {},
   "source": [
    "# Split Learning for Graph Neural Network"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "34cc94bb",
   "metadata": {},
   "source": [
    ">The following code is a demo only. It is **NOT for production** due to system security concerns, so please **DO NOT** use it directly in production."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8e4b05c2-ac8a-4436-9abe-95a335c34947",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Create two participants, alice and bob."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "67a0f4bb-8522-4b91-8fd3-28dd7dee0641",
   "metadata": {},
   "outputs": [],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# In case you already have a running SecretFlow runtime.\n",
    "sf.shutdown()\n",
    "\n",
    "sf.init(parties=['alice', 'bob'], address='local')\n",
    "\n",
    "alice, bob = sf.PYU('alice'), sf.PYU('bob')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8d589dac-2670-40fa-b33c-4abb86db63b2",
   "metadata": {},
   "source": [
    "## Prepare the Dataset"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "52b1f363-ff2e-4d89-a3d6-7fb2e9c69f83",
   "metadata": {},
   "source": [
    "### The cora dataset\n",
    "The [cora](https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz) dataset has two tab-separated files: `cora.cites` and `cora.content`.\n",
    "\n",
    "- The `cora.cites` includes the citation records with two columns: cited_paper_id (target) and citing_paper_id (source).\n",
    "- The `cora.content` includes the paper content records with 1,435 columns: paper_id, 1,433 binary word features, and subject (the class label).\n",
    "\n",
    "Let us use the pre-split version of cora that ships as a built-in dataset of SecretFlow (stored in the `ind.cora.*` files read by the loader below).\n",
    "\n",
    "- The train set includes 140 papers (nodes).\n",
    "- The test set includes 1,000 papers.\n",
    "- The validation set includes 500 papers."
   ]
  },
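  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a7d2e101",
   "metadata": {},
   "source": [
    "The three splits are encoded as boolean masks over all nodes rather than as index lists. The sketch below reuses the `sample_mask` logic of the loader further down on cora's 2,708 nodes, with the train and validation ranges that loader builds (`range(140)` and `range(140, 640)`). Treating the last 1,000 nodes as the test set is an assumption made only for this illustration; the loader reads the actual test indices from `ind.cora.test.index`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d2e102",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_mask(idx, length):\n",
    "    # Same logic as in load_cora: True at the given node indices, False elsewhere.\n",
    "    mask = np.zeros(length)\n",
    "    mask[idx] = 1\n",
    "    return np.array(mask, dtype=bool)\n",
    "\n",
    "\n",
    "toy_num_nodes = 2708  # number of nodes in cora\n",
    "toy_train = sample_mask(range(140), toy_num_nodes)\n",
    "toy_val = sample_mask(range(140, 640), toy_num_nodes)\n",
    "# Assumption for illustration only: the last 1,000 nodes form the test set.\n",
    "toy_test = sample_mask(range(toy_num_nodes - 1000, toy_num_nodes), toy_num_nodes)\n",
    "print(toy_train.sum(), toy_val.sum(), toy_test.sum())  # 140 500 1000"
   ]
  },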
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4b5db30b-0ed7-4cf4-b2bb-b887a9d3fb3f",
   "metadata": {},
   "source": [
    "### Split the dataset\n",
    "\n",
    "Let us split the dataset for split learning.\n",
    "\n",
    "- Alice holds features 1 to 716, and bob holds the remaining 717 features (717 to 1,433).\n",
    "- Alice holds all labels.\n",
    "- Both alice and bob hold all edges."
   ]
  },
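  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a7d2e103",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of this column split, assuming a random binary matrix `toy_nodes` with cora's shape as a hypothetical stand-in for the real features. It applies the same `round(nodes.shape[1] / 2)` split position as the loader below; because Python rounds 716.5 half-to-even, alice gets 716 columns and bob gets 717."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d2e104",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def split_features(nodes):\n",
    "    # Vertical (feature-wise) partition: both parties see every node,\n",
    "    # but each one only holds a subset of the feature columns.\n",
    "    split_pos = round(nodes.shape[1] / 2)  # round(716.5) == 716 (half-to-even)\n",
    "    return nodes[:, :split_pos], nodes[:, split_pos:]\n",
    "\n",
    "\n",
    "# Hypothetical stand-in for cora's binary bag-of-words matrix (2,708 nodes x 1,433 words).\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_nodes = (toy_rng.random((2708, 1433)) < 0.01).astype(np.float32)\n",
    "\n",
    "toy_alice, toy_bob = split_features(toy_nodes)\n",
    "print(toy_alice.shape, toy_bob.shape)  # (2708, 716) (2708, 717)"
   ]
  },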
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ff444d11-d754-43cc-83d9-e11d41bb0a50",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_754102/2692670243.py:27: DeprecationWarning: Please use `csr_matrix` from the `scipy.sparse` namespace, the `scipy.sparse.csr` namespace is deprecated.\n",
      "  objects.append(pickle.load(f, encoding='latin1'))\n",
      "/tmp/ipykernel_754102/2692670243.py:38: FutureWarning: adjacency_matrix will return a scipy.sparse array instead of a matrix in Networkx 3.0.\n",
      "  edge = nx.adjacency_matrix(nx.from_dict_of_lists(graph))\n"
     ]
    }
   ],
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "import scipy.sparse\n",
    "import zipfile\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "from secretflow.utils.simulation.datasets import dataset\n",
    "from secretflow.data.ndarray import load\n",
    "\n",
    "\n",
    "def load_cora():\n",
    "    dataset_zip = dataset('cora')\n",
    "    extract_path = str(Path(dataset_zip).parent)\n",
    "    with zipfile.ZipFile(dataset_zip, 'r') as zip_f:\n",
    "        zip_f.extractall(extract_path)\n",
    "\n",
    "    file_names = [\n",
    "        os.path.join(extract_path, f'ind.cora.{name}')\n",
    "        for name in ['y', 'tx', 'ty', 'allx', 'ally', 'graph']\n",
    "    ]\n",
    "\n",
    "    objects = []\n",
    "    for name in file_names:\n",
    "        with open(name, 'rb') as f:\n",
    "            objects.append(pickle.load(f, encoding='latin1'))\n",
    "\n",
    "    y, tx, ty, allx, ally, graph = tuple(objects)\n",
    "\n",
    "    with open(os.path.join(extract_path, 'ind.cora.test.index'), 'r') as f:\n",
    "        test_idx_reorder = f.readlines()\n",
    "    test_idx_reorder = list(map(lambda s: int(s.strip()), test_idx_reorder))\n",
    "    test_idx_range = np.sort(test_idx_reorder)\n",
    "\n",
    "    nodes = scipy.sparse.vstack((allx, tx)).tolil()\n",
    "    nodes[test_idx_reorder, :] = nodes[test_idx_range, :]\n",
    "    edge = nx.adjacency_matrix(nx.from_dict_of_lists(graph))\n",
    "    # Dense adjacency matrix plus self-loops, so every node also attends to itself.\n",
    "    edge = edge.toarray() + np.eye(edge.shape[1])\n",
    "\n",
    "    labels = np.vstack((ally, ty))\n",
    "    labels[test_idx_reorder, :] = labels[test_idx_range, :]\n",
    "\n",
    "    idx_test = test_idx_range.tolist()\n",
    "    idx_train = range(len(y))\n",
    "    idx_val = range(len(y), len(y) + 500)\n",
    "\n",
    "    def sample_mask(idx, length):\n",
    "        mask = np.zeros(length)\n",
    "        mask[idx] = 1\n",
    "        return np.array(mask, dtype=bool)\n",
    "\n",
    "    idx_train = sample_mask(idx_train, labels.shape[0])\n",
    "    idx_val = sample_mask(idx_val, labels.shape[0])\n",
    "    idx_test = sample_mask(idx_test, labels.shape[0])\n",
    "\n",
    "    y_train = np.zeros(labels.shape)\n",
    "    y_val = np.zeros(labels.shape)\n",
    "    y_test = np.zeros(labels.shape)\n",
    "    y_train[idx_train, :] = labels[idx_train, :]\n",
    "    y_val[idx_val, :] = labels[idx_val, :]\n",
    "    y_test[idx_test, :] = labels[idx_test, :]\n",
    "\n",
    "    nodes = nodes.toarray()\n",
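    "    # Vertical split: Python rounds 716.5 half-to-even, so alice gets the first 716 columns and bob the other 717.\n",
    "    features_split_pos = round(nodes.shape[1] / 2)\n",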
    "    nodes_alice, nodes_bob = (\n",
    "        nodes[:, :features_split_pos],\n",
    "        nodes[:, features_split_pos:],\n",
    "    )\n",
    "    temp_dir = tempfile.mkdtemp()\n",
    "    saved_files = [\n",
    "        os.path.join(temp_dir, name)\n",
    "        for name in [\n",
    "            'edge.npy',\n",
    "            'x_alice.npy',\n",
    "            'x_bob.npy',\n",
    "            'y_train.npy',\n",
    "            'y_val.npy',\n",
    "            'y_test.npy',\n",
    "            'idx_train.npy',\n",
    "            'idx_val.npy',\n",
    "            'idx_test.npy',\n",
    "        ]\n",
    "    ]\n",
    "    np.save(saved_files[0], edge)\n",
    "    np.save(saved_files[1], nodes_alice)\n",
    "    np.save(saved_files[2], nodes_bob)\n",
    "    np.save(saved_files[3], y_train)\n",
    "    np.save(saved_files[4], y_val)\n",
    "    np.save(saved_files[5], y_test)\n",
    "    np.save(saved_files[6], idx_train)\n",
    "    np.save(saved_files[7], idx_val)\n",
    "    np.save(saved_files[8], idx_test)\n",
    "    return saved_files\n",
    "\n",
    "\n",
    "saved_files = load_cora()\n",
    "\n",
    "edge = load({alice: saved_files[0], bob: saved_files[0]})\n",
    "features = load({alice: saved_files[1], bob: saved_files[2]})\n",
    "Y_train = load({alice: saved_files[3]})\n",
    "Y_val = load({alice: saved_files[4]})\n",
    "Y_test = load({alice: saved_files[5]})\n",
    "idx_train = load({alice: saved_files[6]})\n",
    "idx_val = load({alice: saved_files[7]})\n",
    "idx_test = load({alice: saved_files[8]})"
   ]
  },
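  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a7d2e105",
   "metadata": {},
   "source": [
    "`load_cora` converts the dict-of-lists citation graph into a dense adjacency matrix and adds the identity matrix, so that each node counts as its own neighbor and its own features take part in the attention step. The NumPy-only sketch below mirrors that preprocessing on a hypothetical four-node graph, without networkx; like `nx.from_dict_of_lists`, it treats the graph as undirected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d2e106",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def dense_adjacency_with_self_loops(graph, num_nodes):\n",
    "    # Assumes node ids are 0..num_nodes-1, as in cora's dict-of-lists graph.\n",
    "    adj = np.zeros((num_nodes, num_nodes))\n",
    "    for src, neighbors in graph.items():\n",
    "        for dst in neighbors:\n",
    "            # Undirected graph: store every edge in both directions.\n",
    "            adj[src, dst] = 1.0\n",
    "            adj[dst, src] = 1.0\n",
    "    # Self-loops, like `edge.toarray() + np.eye(edge.shape[1])` in load_cora.\n",
    "    return adj + np.eye(num_nodes)\n",
    "\n",
    "\n",
    "# Hypothetical four-paper citation graph in the dict-of-lists format.\n",
    "toy_graph = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}\n",
    "toy_adj = dense_adjacency_with_self_loops(toy_graph, num_nodes=4)\n",
    "print(toy_adj)"
   ]
  },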
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "acc9a773",
   "metadata": {},
   "source": [
    "Since cora is a built-in dataset of SecretFlow, you can also replace the code above with the following snippet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "50100f9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.utils.simulation.datasets import load_cora\n",
    "\n",
    "(edge, features, Y_train, Y_val, Y_test, idx_train, idx_val, idx_test) = load_cora(\n",
    "    [alice, bob]\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fd373dcf",
   "metadata": {},
   "source": [
    "## Build a Graph Neural Network Model\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f5250d2c",
   "metadata": {},
   "source": [
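    "### Implement a graph attention layer\n",
    "\n",
    "Each party embeds its nodes with a graph attention (GAT) layer. For each attention head, the layer projects the node features with a kernel $W$, scores every pair of nodes with $e_{ij} = \\mathrm{LeakyReLU}(a_1^T W h_i + a_2^T W h_j)$, masks out pairs that are not connected in the adjacency matrix, and normalizes the scores with a softmax over each node's neighbors. The heads are then averaged or concatenated."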
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "02de8521",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import activations\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras import constraints, initializers, regularizers\n",
    "from tensorflow.keras.layers import Dropout, Layer, LeakyReLU\n",
    "\n",
    "\n",
    "class GraphAttention(Layer):\n",
    "    def __init__(\n",
    "        self,\n",
    "        F_,\n",
    "        attn_heads=1,\n",
    "        attn_heads_reduction='average',  # {'concat', 'average'}\n",
    "        dropout_rate=0.5,\n",
    "        activation='relu',\n",
    "        use_bias=True,\n",
    "        kernel_initializer='glorot_uniform',\n",
    "        bias_initializer='zeros',\n",
    "        attn_kernel_initializer='glorot_uniform',\n",
    "        kernel_regularizer=None,\n",
    "        bias_regularizer=None,\n",
    "        attn_kernel_regularizer=None,\n",
    "        activity_regularizer=None,\n",
    "        kernel_constraint=None,\n",
    "        bias_constraint=None,\n",
    "        attn_kernel_constraint=None,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        if attn_heads_reduction not in {'concat', 'average'}:\n",
    "            raise ValueError('Possible reduction methods: concat, average')\n",
    "\n",
    "        self.F_ = F_  # Number of output features (F' in the paper)\n",
    "        self.attn_heads = attn_heads  # Number of attention heads (K in the paper)\n",
    "        self.attn_heads_reduction = attn_heads_reduction\n",
    "        self.dropout_rate = dropout_rate  # Internal dropout rate\n",
    "        self.activation = activations.get(activation)\n",
    "        self.use_bias = use_bias\n",
    "\n",
    "        self.kernel_initializer = initializers.get(kernel_initializer)\n",
    "        self.bias_initializer = initializers.get(bias_initializer)\n",
    "        self.attn_kernel_initializer = initializers.get(attn_kernel_initializer)\n",
    "\n",
    "        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n",
    "        self.bias_regularizer = regularizers.get(bias_regularizer)\n",
    "        self.attn_kernel_regularizer = regularizers.get(attn_kernel_regularizer)\n",
    "        self.activity_regularizer = regularizers.get(activity_regularizer)\n",
    "\n",
    "        self.kernel_constraint = constraints.get(kernel_constraint)\n",
    "        self.bias_constraint = constraints.get(bias_constraint)\n",
    "        self.attn_kernel_constraint = constraints.get(attn_kernel_constraint)\n",
    "        self.supports_masking = False\n",
    "\n",
    "        # Populated by build()\n",
    "        self.kernels = []  # Layer kernels for attention heads\n",
    "        self.biases = []  # Layer biases for attention heads\n",
    "        self.attn_kernels = []  # Attention kernels for attention heads\n",
    "\n",
    "        if attn_heads_reduction == 'concat':\n",
    "            # Output will have shape (..., K * F')\n",
    "            self.output_dim = self.F_ * self.attn_heads\n",
    "        else:\n",
    "            # Output will have shape (..., F')\n",
    "            self.output_dim = self.F_\n",
    "\n",
    "        super(GraphAttention, self).__init__(**kwargs)\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        assert len(input_shape) >= 2\n",
    "        F = input_shape[0][-1]\n",
    "\n",
    "        # Initialize weights for each attention head\n",
    "        for head in range(self.attn_heads):\n",
    "            # Layer kernel\n",
    "            kernel = self.add_weight(\n",
    "                shape=(F, self.F_),\n",
    "                initializer=self.kernel_initializer,\n",
    "                regularizer=self.kernel_regularizer,\n",
    "                constraint=self.kernel_constraint,\n",
    "                name='kernel_{}'.format(head),\n",
    "            )\n",
    "            self.kernels.append(kernel)\n",
    "\n",
    "            # Layer bias\n",
    "            if self.use_bias:\n",
    "                bias = self.add_weight(\n",
    "                    shape=(self.F_,),\n",
    "                    initializer=self.bias_initializer,\n",
    "                    regularizer=self.bias_regularizer,\n",
    "                    constraint=self.bias_constraint,\n",
    "                    name='bias_{}'.format(head),\n",
    "                )\n",
    "                self.biases.append(bias)\n",
    "\n",
    "            # Attention kernels\n",
    "            attn_kernel_self = self.add_weight(\n",
    "                shape=(self.F_, 1),\n",
    "                initializer=self.attn_kernel_initializer,\n",
    "                regularizer=self.attn_kernel_regularizer,\n",
    "                constraint=self.attn_kernel_constraint,\n",
    "                name='attn_kernel_self_{}'.format(head),\n",
    "            )\n",
    "            attn_kernel_neighs = self.add_weight(\n",
    "                shape=(self.F_, 1),\n",
    "                initializer=self.attn_kernel_initializer,\n",
    "                regularizer=self.attn_kernel_regularizer,\n",
    "                constraint=self.attn_kernel_constraint,\n",
    "                name='attn_kernel_neigh_{}'.format(head),\n",
    "            )\n",
    "            self.attn_kernels.append([attn_kernel_self, attn_kernel_neighs])\n",
    "        self.built = True\n",
    "\n",
    "    def call(self, inputs):\n",
    "        X = inputs[0]  # Node features (N x F)\n",
    "        A = inputs[1]  # Adjacency matrix (N x N)\n",
    "\n",
    "        outputs = []\n",
    "        for head in range(self.attn_heads):\n",
    "            kernel = self.kernels[head]  # W in the paper (F x F')\n",
    "            attention_kernel = self.attn_kernels[\n",
    "                head\n",
    "            ]  # Attention kernel a in the paper (2F' x 1)\n",
    "\n",
    "            # Compute inputs to attention network\n",
    "            features = K.dot(X, kernel)  # (N x F')\n",
    "\n",
    "            # Compute feature combinations\n",
    "            # Note: [[a_1], [a_2]]^T [[Wh_i], [Wh_j]] = [a_1]^T [Wh_i] + [a_2]^T [Wh_j]\n",
    "            attn_for_self = K.dot(\n",
    "                features, attention_kernel[0]\n",
    "            )  # (N x 1), [a_1]^T [Wh_i]\n",
    "            attn_for_neighs = K.dot(\n",
    "                features, attention_kernel[1]\n",
    "            )  # (N x 1), [a_2]^T [Wh_j]\n",
    "\n",
    "            # Attention head a(Wh_i, Wh_j) = a^T [[Wh_i], [Wh_j]]\n",
    "            dense = attn_for_self + K.transpose(\n",
    "                attn_for_neighs\n",
    "            )  # (N x N) via broadcasting\n",
    "\n",
    "            # Add nonlinearity\n",
    "            dense = LeakyReLU(alpha=0.2)(dense)\n",
    "\n",
    "            # Mask out non-neighbors before the softmax (Vaswani et al., 2017)\n",
    "            mask = -10e9 * (1.0 - A)\n",
    "            dense += mask\n",
    "\n",
    "            # Apply softmax to get attention coefficients\n",
    "            dense = K.softmax(dense)  # (N x N)\n",
    "\n",
    "            # Apply dropout to features and attention coefficients\n",
    "            dropout_attn = Dropout(self.dropout_rate)(dense)  # (N x N)\n",
    "            dropout_feat = Dropout(self.dropout_rate)(features)  # (N x F')\n",
    "\n",
    "            # Linear combination with neighbors' features\n",
    "            node_features = K.dot(dropout_attn, dropout_feat)  # (N x F')\n",
    "\n",
    "            if self.use_bias:\n",
    "                node_features = K.bias_add(node_features, self.biases[head])\n",
    "\n",
    "            # Add output of attention head to final output\n",
    "            outputs.append(node_features)\n",
    "\n",
    "        # Aggregate the heads' output according to the reduction method\n",
    "        if self.attn_heads_reduction == 'concat':\n",
    "            output = K.concatenate(outputs)  # (N x KF')\n",
    "        else:\n",
    "            output = K.mean(K.stack(outputs), axis=0)  # (N x F')\n",
    "\n",
    "        output = self.activation(output)\n",
    "        return output\n",
    "\n",
    "    def compute_output_shape(self, input_shape):\n",
    "        output_shape = input_shape[0][0], self.output_dim\n",
    "        return output_shape\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super().get_config().copy()\n",
    "        config.update(\n",
    "            {\n",
    "                'attn_heads': self.attn_heads,\n",
    "                'attn_heads_reduction': self.attn_heads_reduction,\n",
    "                'F_': self.F_,\n",
    "            }\n",
    "        )\n",
    "        return config"
   ]
  },
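  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a7d2e107",
   "metadata": {},
   "source": [
    "To make the attention step of `GraphAttention.call` concrete, the following is a single-head NumPy re-implementation of the same arithmetic on a hypothetical three-node graph. The random kernel `toy_W` and attention vectors stand in for trained weights, and dropout and bias are left out. Each row of the attention matrix sums to 1, and pairs that are not connected receive zero weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d2e108",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def leaky_relu(x, alpha=0.2):\n",
    "    return np.where(x > 0, x, alpha * x)\n",
    "\n",
    "\n",
    "def gat_head(X, A, W, a_self, a_neigh):\n",
    "    features = X @ W  # (N x F'), Wh_i for every node\n",
    "    attn_for_self = features @ a_self  # (N x 1), a_1^T Wh_i\n",
    "    attn_for_neighs = features @ a_neigh  # (N x 1), a_2^T Wh_j\n",
    "    dense = leaky_relu(attn_for_self + attn_for_neighs.T)  # (N x N) via broadcasting\n",
    "    dense = dense - 10e9 * (1.0 - A)  # mask out non-neighbors\n",
    "    # Row-wise softmax, shifted by the row max for numerical stability.\n",
    "    dense = np.exp(dense - dense.max(axis=1, keepdims=True))\n",
    "    attn = dense / dense.sum(axis=1, keepdims=True)\n",
    "    return attn, attn @ features  # attention weights, new embeddings\n",
    "\n",
    "\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_X = toy_rng.normal(size=(3, 4))  # 3 nodes, 4 input features\n",
    "toy_A = np.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 1.0]])  # with self-loops\n",
    "toy_W = toy_rng.normal(size=(4, 2))  # hypothetical kernel W (F x F')\n",
    "toy_a_self, toy_a_neigh = toy_rng.normal(size=(2, 1)), toy_rng.normal(size=(2, 1))\n",
    "toy_attn, toy_h = gat_head(toy_X, toy_A, toy_W, toy_a_self, toy_a_neigh)\n",
    "print(toy_attn.round(3))\n",
    "print(toy_h.shape)  # (3, 2)"
   ]
  },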
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "27038f11",
   "metadata": {},
   "source": [
    "### Implement a fuse net\n",
    "The fuse net runs on the party that holds the labels. It works as follows:\n",
    "1. Pass the concatenated node embeddings through a stack of dense layers to generate the final node embeddings.\n",
    "2. Feed the final node embeddings into a softmax layer to predict the node class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "543b6a6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ServerNet(tf.keras.layers.Layer):\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channel: int,\n",
    "        hidden_size: int,\n",
    "        num_layer: int,\n",
    "        num_class: int,\n",
    "        dropout: float,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        super(ServerNet, self).__init__(**kwargs)\n",
    "        self.num_class = num_class\n",
    "        self.num_layer = num_layer\n",
    "        self.hidden_size = hidden_size\n",
    "        self.in_channel = in_channel\n",
    "        self.dropout = dropout\n",
    "        self.layers = []\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        self.layers.append(\n",
    "            tf.keras.layers.Dense(self.hidden_size, input_shape=(self.in_channel,))\n",
    "        )\n",
    "        for i in range(self.num_layer - 2):\n",
    "            self.layers.append(\n",
    "                tf.keras.layers.Dense(self.hidden_size, input_shape=(self.hidden_size,))\n",
    "            )\n",
    "        self.layers.append(\n",
    "            tf.keras.layers.Dense(self.num_class, input_shape=(self.hidden_size,))\n",
    "        )\n",
    "\n",
    "        super(ServerNet, self).build(input_shape)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x = inputs\n",
    "        x = Dropout(self.dropout)(x)\n",
    "        for i in range(self.num_layer):\n",
    "            x = Dropout(self.dropout)(x)\n",
    "            x = self.layers[i](x)\n",
    "\n",
    "        return K.softmax(x)\n",
    "\n",
    "    def compute_output_shape(self, input_shape):\n",
    "        output_shape = input_shape[0], self.num_class\n",
    "        return output_shape\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super().get_config().copy()\n",
    "        config.update(\n",
    "            {\n",
    "                'in_channel': self.in_channel,\n",
    "                'hidden_size': self.hidden_size,\n",
    "                'num_layer': self.num_layer,\n",
    "                'num_class': self.num_class,\n",
    "                'dropout': self.dropout,\n",
    "            }\n",
    "        )\n",
    "        return config"
   ]
  },
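  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a7d2e109",
   "metadata": {},
   "source": [
    "`ServerNet` stacks `num_layer` `Dense` layers: one from the concatenated embedding size to `hidden_size`, `num_layer - 2` hidden-to-hidden layers, and a final one to `num_class`, followed by a softmax. Since the `Dense` layers use no activation, the stack is linear up to that softmax. The NumPy sketch below traces the shapes through the same stack for `num_layer=3`; the helper name `server_net_forward` and its random weights are hypothetical, and dropout is omitted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d2e110",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def server_net_forward(x, in_channel, hidden_size, num_layer, num_class, rng):\n",
    "    # Same layer sizes as ServerNet.build(): in -> hidden, (num_layer - 2) x hidden -> hidden, hidden -> classes.\n",
    "    sizes = [in_channel] + [hidden_size] * (num_layer - 1) + [num_class]\n",
    "    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):\n",
    "        W = rng.normal(scale=0.01, size=(fan_in, fan_out))  # hypothetical random weights\n",
    "        x = x @ W  # Dense layer with zero bias and no activation, as in ServerNet\n",
    "    # Row-wise softmax over the classes.\n",
    "    x = np.exp(x - x.max(axis=1, keepdims=True))\n",
    "    return x / x.sum(axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "toy_rng = np.random.default_rng(0)\n",
    "# Hypothetical 256-dim embeddings of 5 nodes from alice and from bob, concatenated.\n",
    "toy_emb = np.concatenate([toy_rng.normal(size=(5, 256)), toy_rng.normal(size=(5, 256))], axis=1)\n",
    "toy_probs = server_net_forward(toy_emb, 512, 256, num_layer=3, num_class=7, rng=toy_rng)\n",
    "print(toy_probs.shape)  # (5, 7)"
   ]
  },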
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ec3ef65f",
   "metadata": {},
   "source": [
    "### Build the base model\n",
    "The base model runs on each party and applies one graph attention layer to the party's local features and the adjacency matrix to produce node embeddings.\n",
    "\n",
    "The node embeddings of all parties are then transferred to the party with the labels for further processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "64356e46",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import Model\n",
    "\n",
    "\n",
    "def create_base_model(\n",
    "    input_shape, n_hidden, l2_reg, num_heads, dropout_rate, learning_rate\n",
    "):\n",
    "    def base_model():\n",
    "        feature_input = tf.keras.Input(shape=(input_shape[1],))\n",
    "        graph_input = tf.keras.Input(shape=(input_shape[0],))\n",
    "        regular = tf.keras.regularizers.l2(l2_reg)\n",
    "        outputs = GraphAttention(\n",
    "            F_=n_hidden,\n",
    "            attn_heads=num_heads,\n",
    "            attn_heads_reduction='average',  # {'concat', 'average'}\n",
    "            dropout_rate=dropout_rate,\n",
    "            activation='relu',\n",
    "            use_bias=True,\n",
    "            kernel_initializer='glorot_uniform',\n",
    "            bias_initializer='zeros',\n",
    "            attn_kernel_initializer='glorot_uniform',\n",
    "            kernel_regularizer=regular,\n",
    "            bias_regularizer=None,\n",
    "            attn_kernel_regularizer=None,\n",
    "            activity_regularizer=None,\n",
    "            kernel_constraint=None,\n",
    "            bias_constraint=None,\n",
    "            attn_kernel_constraint=None,\n",
    "        )([feature_input, graph_input])\n",
    "        # outputs = tf.keras.layers.Flatten()(outputs)\n",
    "        model = Model(inputs=[feature_input, graph_input], outputs=outputs)\n",
    "        model._name = \"embed_model\"\n",
    "        # Compile model\n",
    "        model.summary()\n",
    "        metrics = ['acc']\n",
    "        optimizer = tf.keras.optimizers.get(\n",
    "            {\n",
    "                'class_name': 'adam',\n",
    "                'config': {'learning_rate': learning_rate},\n",
    "            }\n",
    "        )\n",
    "        model.compile(\n",
    "            loss='categorical_crossentropy',\n",
    "            weighted_metrics=metrics,\n",
    "            optimizer=optimizer,\n",
    "        )\n",
    "        return model\n",
    "\n",
    "    return base_model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3300e7c8",
   "metadata": {},
   "source": [
    "### Build the fuse model\n",
    "The fuse model takes the concatenated node embeddings from all parties as input. It runs only on the party that holds the labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a8a58cd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "\n",
    "def create_fuse_model(hidden_units, hidden_size, n_classes, layer_num, learning_rate):\n",
    "    def fuse_model():\n",
    "        inputs = [keras.Input(shape=size) for size in hidden_units]\n",
    "        x = layers.concatenate(inputs)\n",
    "        input_shape = x.shape[-1]\n",
    "        y_pred = ServerNet(\n",
    "            in_channel=input_shape,\n",
    "            hidden_size=hidden_size,\n",
    "            num_layer=layer_num,\n",
    "            num_class=n_classes,\n",
    "            dropout=0.0,\n",
    "        )(x)\n",
    "        # Create the model.\n",
    "        model = keras.Model(inputs=inputs, outputs=y_pred, name=\"fuse_model\")\n",
    "        model.summary()\n",
    "        metrics = ['acc']\n",
    "        optimizer = tf.keras.optimizers.get(\n",
    "            {\n",
    "                'class_name': 'adam',\n",
    "                'config': {'learning_rate': learning_rate},\n",
    "            }\n",
    "        )\n",
    "        model.compile(\n",
    "            loss='categorical_crossentropy',\n",
    "            weighted_metrics=metrics,\n",
    "            optimizer=optimizer,\n",
    "        )\n",
    "        return model\n",
    "\n",
    "    return fuse_model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d25e7182",
   "metadata": {},
   "source": [
    "## Train GNN model with split learning\n",
    "\n",
    "Let us build a split learning model for training.\n",
    "\n",
    "Alice, who holds the labels, owns both a base model and the fuse model, while bob owns only a base model.\n",
    "\n",
    "The overall model structure is as follows:\n",
    "\n",
    "<img alt=\"split_learning_gnn_model.png\" src=\"resources/split_gnn.svg\" width=\"400\">"
   ]
  },
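  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a7d2e111",
   "metadata": {},
   "source": [
    "Conceptually, one training step of this setup proceeds as follows: each party runs its base model on its own features, bob sends his embeddings to alice, alice runs the fuse model and computes the loss against her labels, and the gradient with respect to each party's embeddings is sent back so that each party can update its own base model. The NumPy sketch below illustrates only that data flow, with hypothetical linear base models, a linear-softmax fuse model, and no graph structure; it is not how `SLModel` is implemented internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d2e112",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def softmax_rows(z):\n",
    "    z = np.exp(z - z.max(axis=1, keepdims=True))\n",
    "    return z / z.sum(axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "def split_learning_step(X_a, X_b, Y, W_a, W_b, W_f, lr):\n",
    "    # 1. Each party computes its embedding locally on its own features.\n",
    "    h_a, h_b = X_a @ W_a, X_b @ W_b\n",
    "    # 2. Bob sends h_b to alice, who concatenates the embeddings and runs the fuse model.\n",
    "    h = np.concatenate([h_a, h_b], axis=1)\n",
    "    probs = softmax_rows(h @ W_f)\n",
    "    loss = -np.mean(np.sum(Y * np.log(probs), axis=1))\n",
    "    # 3. Alice backpropagates the cross-entropy loss through the fuse model.\n",
    "    d_logits = (probs - Y) / X_a.shape[0]\n",
    "    d_W_f = h.T @ d_logits\n",
    "    d_h = d_logits @ W_f.T\n",
    "    # 4. Alice splits the embedding gradient and sends bob his part.\n",
    "    d_h_a, d_h_b = d_h[:, : h_a.shape[1]], d_h[:, h_a.shape[1] :]\n",
    "    # 5. Every party updates its own weights in place.\n",
    "    W_f -= lr * d_W_f\n",
    "    W_a -= lr * (X_a.T @ d_h_a)\n",
    "    W_b -= lr * (X_b.T @ d_h_b)\n",
    "    return loss\n",
    "\n",
    "\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_X_a, toy_X_b = toy_rng.normal(size=(8, 3)), toy_rng.normal(size=(8, 2))  # feature halves\n",
    "toy_Y = np.eye(3)[toy_rng.integers(0, 3, size=8)]  # one-hot labels, held by alice only\n",
    "toy_W_a, toy_W_b = 0.1 * toy_rng.normal(size=(3, 4)), 0.1 * toy_rng.normal(size=(2, 4))\n",
    "toy_W_f = 0.1 * toy_rng.normal(size=(8, 3))  # fuse weights on the 4 + 4 concatenated dims\n",
    "toy_losses = [\n",
    "    split_learning_step(toy_X_a, toy_X_b, toy_Y, toy_W_a, toy_W_b, toy_W_f, lr=0.1)\n",
    "    for _ in range(200)\n",
    "]\n",
    "print(round(toy_losses[0], 3), round(toy_losses[-1], 3))"
   ]
  },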
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "df19c1eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.ml.nn import SLModel\n",
    "\n",
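    "hidden_size = 256  # Embedding size of each base model and hidden size of the fuse net\n",
    "n_classes = 7  # Number of paper subjects in cora\n",
    "layer_num = 3  # Number of dense layers in the fuse net\n",
    "learning_rate = 1e-3\n",
    "dropout_rate = 0.0\n",
    "l2_reg = 0.1\n",
    "num_heads = 4  # Number of attention heads in each GraphAttention layer\n",
    "epochs = 10\n",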
    "\n",
    "partition_shapes = features.partition_shape()\n",
    "\n",
    "input_shape_alice = partition_shapes[alice]\n",
    "input_shape_bob = partition_shapes[bob]\n",
    "\n",
    "sl_model = SLModel(\n",
    "    base_model_dict={\n",
    "        alice: create_base_model(\n",
    "            input_shape_alice,\n",
    "            hidden_size,\n",
    "            l2_reg,\n",
    "            num_heads,\n",
    "            dropout_rate,\n",
    "            learning_rate,\n",
    "        ),\n",
    "        bob: create_base_model(\n",
    "            input_shape_bob,\n",
    "            hidden_size,\n",
    "            l2_reg,\n",
    "            num_heads,\n",
    "            dropout_rate,\n",
    "            learning_rate,\n",
    "        ),\n",
    "    },\n",
    "    device_y=alice,\n",
    "    model_fuse=create_fuse_model(\n",
    "        [hidden_size, hidden_size], hidden_size, n_classes, layer_num, learning_rate\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d8fe270c",
   "metadata": {},
   "source": [
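    "Fit the model. Training is full-batch (`batch_size` equals the number of nodes), and `sample_weight=idx_train` gives weight 0 to every node outside the 140 training nodes; `validation_data` applies `idx_val` in the same way."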
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "078727f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:05<00:00,  5.67s/it, epoch: 0/10 -  train_loss:0.10079389065504074  train_acc:0.12857143580913544  val_loss:0.35441863536834717  val_acc:0.3880000114440918 ]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.02s/it, epoch: 1/10 -  train_loss:0.2256714105606079  train_acc:0.44843751192092896  val_loss:0.3481321930885315  val_acc:0.5640000104904175 ]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.01s/it, epoch: 2/10 -  train_loss:0.22043751180171967  train_acc:0.637499988079071  val_loss:0.34046509861946106  val_acc:0.6320000290870667 ]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.09s/it, epoch: 3/10 -  train_loss:0.21422825753688812  train_acc:0.703125  val_loss:0.33100318908691406  val_acc:0.6539999842643738 ]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.04s/it, epoch: 4/10 -  train_loss:0.20669467747211456  train_acc:0.7250000238418579  val_loss:0.3193798065185547  val_acc:0.6840000152587891 ]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.10s/it, epoch: 5/10 -  train_loss:0.19755281507968903  train_acc:0.7484375238418579  val_loss:0.30531173944473267  val_acc:0.7080000042915344 ]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.03s/it, epoch: 6/10 -  train_loss:0.1866169422864914  train_acc:0.7671874761581421  val_loss:0.28866109251976013  val_acc:0.7279999852180481 ]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.03s/it, epoch: 7/10 -  train_loss:0.17386208474636078  train_acc:0.7828124761581421  val_loss:0.26949769258499146  val_acc:0.7319999933242798 ]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.04s/it, epoch: 8/10 -  train_loss:0.1594790667295456  train_acc:0.785937488079071  val_loss:0.24824994802474976  val_acc:0.7319999933242798 ]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.06s/it, epoch: 9/10 -  train_loss:0.1439441591501236  train_acc:0.785937488079071  val_loss:0.2257150411605835  val_acc:0.7400000095367432 ]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'train_loss': [0.10079389,\n",
       "  0.22567141,\n",
       "  0.22043751,\n",
       "  0.21422826,\n",
       "  0.20669468,\n",
       "  0.19755282,\n",
       "  0.18661694,\n",
       "  0.17386208,\n",
       "  0.15947907,\n",
       "  0.14394416],\n",
       " 'train_acc': [0.12857144,\n",
       "  0.4484375,\n",
       "  0.6375,\n",
       "  0.703125,\n",
       "  0.725,\n",
       "  0.7484375,\n",
       "  0.7671875,\n",
       "  0.7828125,\n",
       "  0.7859375,\n",
       "  0.7859375],\n",
       " 'val_loss': [0.35441864,\n",
       "  0.3481322,\n",
       "  0.3404651,\n",
       "  0.3310032,\n",
       "  0.3193798,\n",
       "  0.30531174,\n",
       "  0.2886611,\n",
       "  0.2694977,\n",
       "  0.24824995,\n",
       "  0.22571504],\n",
       " 'val_acc': [0.388,\n",
       "  0.564,\n",
       "  0.632,\n",
       "  0.654,\n",
       "  0.684,\n",
       "  0.708,\n",
       "  0.728,\n",
       "  0.732,\n",
       "  0.732,\n",
       "  0.74]}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sl_model.fit(\n",
    "    x=[features, edge],\n",
    "    y=Y_train,\n",
    "    epochs=epochs,\n",
    "    batch_size=input_shape_alice[0],\n",
    "    sample_weight=idx_train,\n",
    "    validation_data=([features, edge], Y_val, idx_val),\n",
    ")"
   ]
  },
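  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a7d2e113",
   "metadata": {},
   "source": [
    "The reported losses are small partly because of how the masked loss is averaged. Training is full-batch, and `sample_weight=idx_train` gives weight 0 to every node outside the training set. Assuming the compiled loss uses Keras's default `sum_over_batch_size` reduction, the weighted sum is divided by all 2,708 nodes in the batch rather than by the 140 weighted ones, whereas the accuracy, computed as a weighted metric, averages over the weighted nodes only. The sketch below works through that arithmetic for a model that predicts a uniform distribution over the 7 classes; the result, about 0.1006, is close to the first-epoch `train_loss` of about 0.1008 shown above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d2e114",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def masked_loss_averages(per_node_loss, weights):\n",
    "    weighted_sum = np.sum(per_node_loss * weights)\n",
    "    # Divided by every node in the batch (sum-over-batch-size style) ...\n",
    "    over_batch = weighted_sum / per_node_loss.shape[0]\n",
    "    # ... versus divided by the total weight, i.e. the number of labeled nodes.\n",
    "    over_labeled = weighted_sum / np.sum(weights)\n",
    "    return over_batch, over_labeled\n",
    "\n",
    "\n",
    "toy_num_nodes, toy_num_train, toy_num_classes = 2708, 140, 7\n",
    "toy_weights = np.zeros(toy_num_nodes)\n",
    "toy_weights[:toy_num_train] = 1.0  # like idx_train: the first 140 nodes\n",
    "# A uniform prediction over 7 classes has cross-entropy log(7) on every node.\n",
    "toy_ce = np.full(toy_num_nodes, np.log(toy_num_classes))\n",
    "toy_over_batch, toy_over_labeled = masked_loss_averages(toy_ce, toy_weights)\n",
    "print(round(toy_over_batch, 4), round(toy_over_labeled, 4))  # 0.1006 1.9459"
   ]
  },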
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "09de4553",
   "metadata": {},
   "source": [
    "Evaluate the GNN model on the test set, using `idx_test` as the sample weights so that only the 1,000 test nodes are scored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6dce891f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate Processing:: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s, loss:0.4411766827106476 acc:0.7720000147819519]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'loss': 0.44117668, 'acc': 0.772}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sl_model.evaluate(\n",
    "    x=[features, edge],\n",
    "    y=Y_test,\n",
    "    batch_size=input_shape_alice[0],\n",
    "    sample_weight=idx_test,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b357a1ff-9674-4258-87d7-9278d3c0406d",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ceb119b1-dd74-4ad1-9c00-69b0928e125f",
   "metadata": {},
   "source": [
    "In this tutorial, we demonstrate how to train a graph neural network with split learning. This is a very basic implementation, and there are several directions we will explore in the future:\n",
    "\n",
    "- SGD on large graphs: in the example above, each training step has to run `prepare`, `aggregate`, and `update` on the whole graph, which is extremely computation intensive. Stochastic minibatch training should be used to reduce computation and memory consumption.\n",
    "- Partially aligned graphs: in the example above, the parties must share the same node set, which may not hold in real cases. We want to explore the case where the parties share only a common subset of nodes."
   ]
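  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a7d2e115",
   "metadata": {},
   "source": [
    "As a rough illustration of the first direction, one simple way to form minibatches is to sample a subset of nodes and keep only the adjacency entries among them (the induced subgraph), so that each step works on a `batch_size x batch_size` block instead of the full `N x N` matrix. The NumPy sketch below shows only this sampling step with a hypothetical helper, `sample_induced_subgraph`; it is not a SecretFlow API, it simply drops edges that leave the batch, and in a split setting both parties would have to use the same sampled indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d2e116",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_induced_subgraph(adj, features, labels, batch_size, rng):\n",
    "    # Sample a minibatch of node indices without replacement.\n",
    "    idx = rng.choice(adj.shape[0], size=batch_size, replace=False)\n",
    "    # Keep only the edges among the sampled nodes; self-loops on the diagonal survive.\n",
    "    sub_adj = adj[np.ix_(idx, idx)]\n",
    "    return idx, sub_adj, features[idx], labels[idx]\n",
    "\n",
    "\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_n = 10\n",
    "toy_path_adj = np.eye(toy_n)  # hypothetical path graph 0-1-...-9 with self-loops\n",
    "for i in range(toy_n - 1):\n",
    "    toy_path_adj[i, i + 1] = toy_path_adj[i + 1, i] = 1.0\n",
    "toy_feats = toy_rng.normal(size=(toy_n, 4))\n",
    "toy_labels = np.eye(3)[toy_rng.integers(0, 3, size=toy_n)]\n",
    "mb_idx, mb_adj, mb_x, mb_y = sample_induced_subgraph(\n",
    "    toy_path_adj, toy_feats, toy_labels, batch_size=4, rng=toy_rng\n",
    ")\n",
    "print(mb_idx, mb_adj.shape, mb_x.shape, mb_y.shape)"
   ]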
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('3.8')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "ae1fdd5fd034b7d694352220485921694ff89198520409089b4646721fce11ca"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
